{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "118UKH5bWCGa"
   },
   "source": [
    "# DALL·E mini - Embedding Retrain Preparation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We'll start with the dalle-mini model for faster experimentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "K6CxW2o42f-w"
   },
   "outputs": [],
   "source": [
    "DALLE_MODEL = \"dalle-mini/dalle-mini/mini-1:v0\"  # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
    "DALLE_COMMIT_ID = None\n",
    "\n",
    "# # dalle-mega\n",
    "# DALLE_MODEL = \"dalle-mini/dalle-mini/mega-1-fp16:latest\"  # can be wandb artifact or 🤗 Hub or local folder or google bucket\n",
    "# DALLE_COMMIT_ID = None"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "Yv-aR3t4Oe5v"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import jax\n",
    "import jax.numpy as jnp\n",
    "\n",
    "# check how many devices are available\n",
    "jax.local_device_count()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Load model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We load the model twice to keep a copy of the original parameters."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "92zYmvsQ38vL"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact mini-1:v0, 1673.43MB. 7 files... Done. 0:0:1.2\n",
      "tcmalloc: large alloc 1751343104 bytes == 0x56011a2c0000 @  0x7f143aaa9680 0x7f143aaca824 0x5600d248253b 0x5600d24c30ba 0x5600d2599a58 0x5600d24f548d 0x5600d23cf328 0x5600d25af66d 0x5600d24f5825 0x5600d24532da 0x5600d24eafe3 0x5600d24ec709 0x5600d249a1ea 0x5600d252be7a 0x5600d24eafe3 0x5600d24ec709 0x5600d245273d 0x5600d24eafe3 0x5600d2597a7c 0x5600d24ebdbb 0x5600d25ce33e 0x5600d24f5571 0x5600d2452088 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d24f5f94 0x5600d24532da 0x5600d24ebbe4\n",
      "\u001b[34m\u001b[1mwandb\u001b[0m: Downloading large artifact mini-1:v0, 1673.43MB. 7 files... Done. 0:0:1.2\n",
      "tcmalloc: large alloc 1751343104 bytes == 0x56011a2c0000 @  0x7f143aaa9680 0x7f143aaca824 0x5600d248253b 0x5600d24c30ba 0x5600d2599a58 0x5600d24f548d 0x5600d23cf328 0x5600d25af66d 0x5600d24f5825 0x5600d24532da 0x5600d24eafe3 0x5600d24ec709 0x5600d249a1ea 0x5600d252be7a 0x5600d24eafe3 0x5600d24ec709 0x5600d245273d 0x5600d24eafe3 0x5600d2597a7c 0x5600d24ebdbb 0x5600d25ce33e 0x5600d24f5571 0x5600d2452088 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d252f0fc 0x5600d24e07cb 0x5600d24f5f94 0x5600d24532da 0x5600d24ebbe4\n"
     ]
    }
   ],
   "source": [
    "# Load model\n",
    "from dalle_mini import DalleBart, DalleBartProcessor\n",
    "\n",
    "model, params = DalleBart.from_pretrained(\n",
    "    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
    ")\n",
    "\n",
    "_, params_original = DalleBart.from_pretrained(\n",
    "    DALLE_MODEL, revision=DALLE_COMMIT_ID, dtype=jnp.float16, _do_init=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model surgery: remove layers to be retrained"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's take a look at the params tree."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "437833712"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(x.size for x in jax.tree_leaves(params))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{\n",
      "  \"lm_head\": {\n",
      "    \"kernel\": [\n",
      "      1024,\n",
      "      16385\n",
      "    ]\n",
      "  },\n",
      "  \"model\": {\n",
      "    \"decoder\": {\n",
      "      \"embed_positions\": {\n",
      "        \"embedding\": [\n",
      "          256,\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"embed_tokens\": {\n",
      "        \"embedding\": [\n",
      "          16385,\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"final_ln\": {\n",
      "        \"bias\": [\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"layernorm_embedding\": {\n",
      "        \"bias\": [\n",
      "          1024\n",
      "        ],\n",
      "        \"scale\": [\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"layers\": {\n",
      "        \"FlaxBartDecoderLayers\": {\n",
      "          \"FlaxBartAttention_0\": {\n",
      "            \"k_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"out_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"q_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"v_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            }\n",
      "          },\n",
      "          \"FlaxBartAttention_1\": {\n",
      "            \"k_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"out_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"q_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"v_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            }\n",
      "          },\n",
      "          \"GLU_0\": {\n",
      "            \"Dense_0\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                2730\n",
      "              ]\n",
      "            },\n",
      "            \"Dense_1\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                2730\n",
      "              ]\n",
      "            },\n",
      "            \"Dense_2\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                2730,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"LayerNorm_0\": {\n",
      "              \"bias\": [\n",
      "                12,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"LayerNorm_1\": {\n",
      "              \"bias\": [\n",
      "                12,\n",
      "                2730\n",
      "              ]\n",
      "            }\n",
      "          },\n",
      "          \"LayerNorm_0\": {\n",
      "            \"bias\": [\n",
      "              12,\n",
      "              1024\n",
      "            ]\n",
      "          },\n",
      "          \"LayerNorm_1\": {\n",
      "            \"bias\": [\n",
      "              12,\n",
      "              1024\n",
      "            ],\n",
      "            \"scale\": [\n",
      "              12,\n",
      "              1024\n",
      "            ]\n",
      "          },\n",
      "          \"LayerNorm_2\": {\n",
      "            \"bias\": [\n",
      "              12,\n",
      "              1024\n",
      "            ]\n",
      "          },\n",
      "          \"LayerNorm_3\": {\n",
      "            \"bias\": [\n",
      "              12,\n",
      "              1024\n",
      "            ],\n",
      "            \"scale\": [\n",
      "              12,\n",
      "              1024\n",
      "            ]\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "    },\n",
      "    \"encoder\": {\n",
      "      \"embed_positions\": {\n",
      "        \"embedding\": [\n",
      "          64,\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"embed_tokens\": {\n",
      "        \"embedding\": [\n",
      "          50264,\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"final_ln\": {\n",
      "        \"bias\": [\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"layernorm_embedding\": {\n",
      "        \"bias\": [\n",
      "          1024\n",
      "        ],\n",
      "        \"scale\": [\n",
      "          1024\n",
      "        ]\n",
      "      },\n",
      "      \"layers\": {\n",
      "        \"FlaxBartEncoderLayers\": {\n",
      "          \"FlaxBartAttention_0\": {\n",
      "            \"k_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"out_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"q_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"v_proj\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                1024\n",
      "              ]\n",
      "            }\n",
      "          },\n",
      "          \"GLU_0\": {\n",
      "            \"Dense_0\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                2730\n",
      "              ]\n",
      "            },\n",
      "            \"Dense_1\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                1024,\n",
      "                2730\n",
      "              ]\n",
      "            },\n",
      "            \"Dense_2\": {\n",
      "              \"kernel\": [\n",
      "                12,\n",
      "                2730,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"LayerNorm_0\": {\n",
      "              \"bias\": [\n",
      "                12,\n",
      "                1024\n",
      "              ]\n",
      "            },\n",
      "            \"LayerNorm_1\": {\n",
      "              \"bias\": [\n",
      "                12,\n",
      "                2730\n",
      "              ]\n",
      "            }\n",
      "          },\n",
      "          \"LayerNorm_0\": {\n",
      "            \"bias\": [\n",
      "              12,\n",
      "              1024\n",
      "            ]\n",
      "          },\n",
      "          \"LayerNorm_1\": {\n",
      "            \"bias\": [\n",
      "              12,\n",
      "              1024\n",
      "            ],\n",
      "            \"scale\": [\n",
      "              12,\n",
      "              1024\n",
      "            ]\n",
      "          }\n",
      "        }\n",
      "      }\n",
      "    }\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "import json\n",
    "\n",
    "tree = jax.tree_map(lambda x: x.shape, params)\n",
    "print(json.dumps(tree, indent=2))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We will remove or reinitialize:\n",
    "- `lm_head`\n",
    "- `model.decoder.embed_positions`\n",
    "- `model.decoder.embed_tokens`\n",
    "- `model.decoder.final_ln`\n",
    "- `model.decoder.layernorm_embedding`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "del params[\"lm_head\"]\n",
    "for layer in [\"embed_positions\", \"embed_tokens\", \"final_ln\", \"layernorm_embedding\"]:\n",
    "    del params[\"model\"][\"decoder\"][layer]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'model': {'decoder': {'layers': {'FlaxBartDecoderLayers': {'FlaxBartAttention_0': {'k_proj': {'kernel': (12,\n",
       "        1024,\n",
       "        1024)},\n",
       "      'out_proj': {'kernel': (12, 1024, 1024)},\n",
       "      'q_proj': {'kernel': (12, 1024, 1024)},\n",
       "      'v_proj': {'kernel': (12, 1024, 1024)}},\n",
       "     'FlaxBartAttention_1': {'k_proj': {'kernel': (12, 1024, 1024)},\n",
       "      'out_proj': {'kernel': (12, 1024, 1024)},\n",
       "      'q_proj': {'kernel': (12, 1024, 1024)},\n",
       "      'v_proj': {'kernel': (12, 1024, 1024)}},\n",
       "     'GLU_0': {'Dense_0': {'kernel': (12, 1024, 2730)},\n",
       "      'Dense_1': {'kernel': (12, 1024, 2730)},\n",
       "      'Dense_2': {'kernel': (12, 2730, 1024)},\n",
       "      'LayerNorm_0': {'bias': (12, 1024)},\n",
       "      'LayerNorm_1': {'bias': (12, 2730)}},\n",
       "     'LayerNorm_0': {'bias': (12, 1024)},\n",
       "     'LayerNorm_1': {'bias': (12, 1024), 'scale': (12, 1024)},\n",
       "     'LayerNorm_2': {'bias': (12, 1024)},\n",
       "     'LayerNorm_3': {'bias': (12, 1024), 'scale': (12, 1024)}}}},\n",
       "  'encoder': {'embed_positions': {'embedding': (64, 1024)},\n",
       "   'embed_tokens': {'embedding': (50264, 1024)},\n",
       "   'final_ln': {'bias': (1024,)},\n",
       "   'layernorm_embedding': {'bias': (1024,), 'scale': (1024,)},\n",
       "   'layers': {'FlaxBartEncoderLayers': {'FlaxBartAttention_0': {'k_proj': {'kernel': (12,\n",
       "        1024,\n",
       "        1024)},\n",
       "      'out_proj': {'kernel': (12, 1024, 1024)},\n",
       "      'q_proj': {'kernel': (12, 1024, 1024)},\n",
       "      'v_proj': {'kernel': (12, 1024, 1024)}},\n",
       "     'GLU_0': {'Dense_0': {'kernel': (12, 1024, 2730)},\n",
       "      'Dense_1': {'kernel': (12, 1024, 2730)},\n",
       "      'Dense_2': {'kernel': (12, 2730, 1024)},\n",
       "      'LayerNorm_0': {'bias': (12, 1024)},\n",
       "      'LayerNorm_1': {'bias': (12, 2730)}},\n",
       "     'LayerNorm_0': {'bias': (12, 1024)},\n",
       "     'LayerNorm_1': {'bias': (12, 1024), 'scale': (12, 1024)}}}}}}"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.tree_map(lambda x: x.shape, params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "404012016"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "sum(x.size for x in jax.tree_leaves(params))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Reinitialize layers"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We save a checkpoint and reload it again. It does not automatically reinitialize the missing keys, but it sets `_missing_keys` appropriately so we can initialize them later. We could do the same by simply setting that property ourselves, but I'll refrain from doing so because it's a private implementation detail."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "trimmed_checkpoint = \"mini-trimmed\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "tcmalloc: large alloc 1610424320 bytes == 0x5632d11c8000 @  0x7f95ccad0680 0x7f95ccaf0bdd 0x7f95be99e29f 0x7f95be9a7750 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a4fc4 0x7f95be9a571e 0x5630fb630f94 0x5630fb58e2da 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb626be4 0x5630fb58d088 0x5630fb625fe3 0x5630fb627709 0x5630fb58d73d 0x5630fb625fe3 0x5630fb6d2a7c 0x5630fb626dbb 0x5630fb70933e 0x5630fb630571\n",
      "tcmalloc: large alloc 3231449088 bytes == 0x56333119a000 @  0x7f95ccad0680 0x7f95ccaf0bdd 0x7f95be99e29f 0x7f95be9a7750 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a4fc4 0x7f95be9a571e 0x5630fb630f94 0x5630fb58e2da 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb626be4 0x5630fb58d088 0x5630fb625fe3 0x5630fb627709 0x5630fb58d73d 0x5630fb625fe3 0x5630fb6d2a7c 0x5630fb626dbb 0x5630fb70933e 0x5630fb630571\n"
     ]
    }
   ],
   "source": [
    "model.save_pretrained(trimmed_checkpoint, params=params)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "The checkpoint mini-trimmed is missing required keys: {('model', 'decoder', 'embed_tokens', 'embedding'), ('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layernorm_embedding', 'bias')}. Make sure to call model.init_weights to initialize the missing weights.\n",
      "Some weights of DalleBart were not initialized from the model checkpoint at mini-trimmed and are newly initialized: {('model', 'decoder', 'embed_tokens', 'embedding'), ('lm_head', 'kernel'), ('model', 'decoder', 'embed_positions', 'embedding'), ('model', 'decoder', 'final_ln', 'bias'), ('model', 'decoder', 'layernorm_embedding', 'scale'), ('model', 'decoder', 'layernorm_embedding', 'bias')}\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    }
   ],
   "source": [
    "model, params = DalleBart.from_pretrained(\n",
    "    trimmed_checkpoint, revision=None, dtype=jnp.float16, _do_init=False\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{('lm_head', 'kernel'),\n",
       " ('model', 'decoder', 'embed_positions', 'embedding'),\n",
       " ('model', 'decoder', 'embed_tokens', 'embedding'),\n",
       " ('model', 'decoder', 'final_ln', 'bias'),\n",
       " ('model', 'decoder', 'layernorm_embedding', 'bias'),\n",
       " ('model', 'decoder', 'layernorm_embedding', 'scale')}"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model._missing_keys"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "params_reinit = model.init_weights(model.key, model.input_shape, params=params)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Verification"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The structure should be the same as the original `params` dict. Re-initialized layers should have different parameters, but existing layers should be the same."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FrozenDict({\n",
       "    lm_head: {\n",
       "        kernel: (1024, 16385),\n",
       "    },\n",
       "    model: {\n",
       "        decoder: {\n",
       "            embed_positions: {\n",
       "                embedding: (256, 1024),\n",
       "            },\n",
       "            embed_tokens: {\n",
       "                embedding: (16385, 1024),\n",
       "            },\n",
       "            final_ln: {\n",
       "                bias: (1024,),\n",
       "            },\n",
       "            layernorm_embedding: {\n",
       "                bias: (1024,),\n",
       "                scale: (1024,),\n",
       "            },\n",
       "            layers: {\n",
       "                FlaxBartDecoderLayers: {\n",
       "                    FlaxBartAttention_0: {\n",
       "                        k_proj: {\n",
       "                            kernel: (12, 1024, 1024),\n",
       "                        },\n",
       "                        out_proj: {\n",
       "                            kernel: (12, 1024, 1024),\n",
       "                        },\n",
       "                        q_proj: {\n",
       "                            kernel: (12, 1024, 1024),\n",
       "                        },\n",
       "                        v_proj: {\n",
       "                            kernel: (12, 1024, 1024),\n",
       "                        },\n",
       "                    },\n",
       "                    FlaxBartAttention_1: {\n",
       "                        k_proj: {\n",
       "                            kernel: (12, 1024, 1024),\n",
       "                        },\n",
       "                        out_proj: {\n",
       "                            kernel: (12, 1024, 1024),\n",
       "                        },\n",
       "                        q_proj: {\n",
       "                            kernel: (12, 1024, 1024),\n",
       "                        },\n",
       "                        v_proj: {\n",
       "                            kernel: (12, 1024, 1024),\n",
       "                        },\n",
       "                    },\n",
       "                    GLU_0: {\n",
       "                        Dense_0: {\n",
       "                            kernel: (12, 1024, 2730),\n",
       "                        },\n",
       "                        Dense_1: {\n",
       "                            kernel: (12, 1024, 2730),\n",
       "                        },\n",
       "                        Dense_2: {\n",
       "                            kernel: (12, 2730, 1024),\n",
       "                        },\n",
       "                        LayerNorm_0: {\n",
       "                            bias: (12, 1024),\n",
       "                        },\n",
       "                        LayerNorm_1: {\n",
       "                            bias: (12, 2730),\n",
       "                        },\n",
       "                    },\n",
       "                    LayerNorm_0: {\n",
       "                        bias: (12, 1024),\n",
       "                    },\n",
       "                    LayerNorm_1: {\n",
       "                        bias: (12, 1024),\n",
       "                        scale: (12, 1024),\n",
       "                    },\n",
       "                    LayerNorm_2: {\n",
       "                        bias: (12, 1024),\n",
       "                    },\n",
       "                    LayerNorm_3: {\n",
       "                        bias: (12, 1024),\n",
       "                        scale: (12, 1024),\n",
       "                    },\n",
       "                },\n",
       "            },\n",
       "        },\n",
       "        encoder: {\n",
       "            embed_positions: {\n",
       "                embedding: (64, 1024),\n",
       "            },\n",
       "            embed_tokens: {\n",
       "                embedding: (50264, 1024),\n",
       "            },\n",
       "            final_ln: {\n",
       "                bias: (1024,),\n",
       "            },\n",
       "            layernorm_embedding: {\n",
       "                bias: (1024,),\n",
       "                scale: (1024,),\n",
       "            },\n",
       "            layers: {\n",
       "                FlaxBartEncoderLayers: {\n",
       "                    FlaxBartAttention_0: {\n",
       "                        k_proj: {\n",
       "                            kernel: (12, 1024, 1024),\n",
       "                        },\n",
       "                        out_proj: {\n",
       "                            kernel: (12, 1024, 1024),\n",
       "                        },\n",
       "                        q_proj: {\n",
       "                            kernel: (12, 1024, 1024),\n",
       "                        },\n",
       "                        v_proj: {\n",
       "                            kernel: (12, 1024, 1024),\n",
       "                        },\n",
       "                    },\n",
       "                    GLU_0: {\n",
       "                        Dense_0: {\n",
       "                            kernel: (12, 1024, 2730),\n",
       "                        },\n",
       "                        Dense_1: {\n",
       "                            kernel: (12, 1024, 2730),\n",
       "                        },\n",
       "                        Dense_2: {\n",
       "                            kernel: (12, 2730, 1024),\n",
       "                        },\n",
       "                        LayerNorm_0: {\n",
       "                            bias: (12, 1024),\n",
       "                        },\n",
       "                        LayerNorm_1: {\n",
       "                            bias: (12, 2730),\n",
       "                        },\n",
       "                    },\n",
       "                    LayerNorm_0: {\n",
       "                        bias: (12, 1024),\n",
       "                    },\n",
       "                    LayerNorm_1: {\n",
       "                        bias: (12, 1024),\n",
       "                        scale: (12, 1024),\n",
       "                    },\n",
       "                },\n",
       "            },\n",
       "        },\n",
       "    },\n",
       "})"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "jax.tree_map(lambda x: x.shape, params_reinit)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "FrozenDict({\n",
       "    embedding: DeviceArray([[ 0.00582082, -0.04113895,  0.00918633, ..., -0.00530822,\n",
       "                   0.01297319,  0.02720674],\n",
       "                 [ 0.03540739,  0.03676804, -0.02924041, ...,  0.00163185,\n",
       "                  -0.01938273, -0.02105987],\n",
       "                 [ 0.00478452, -0.03438002, -0.0024974 , ..., -0.03892584,\n",
       "                   0.01721252,  0.02605445],\n",
       "                 ...,\n",
       "                 [ 0.02495495,  0.00559381, -0.01588043, ...,  0.01393714,\n",
       "                  -0.01824111, -0.02007291],\n",
       "                 [ 0.00983252, -0.00180564, -0.01686333, ..., -0.01001718,\n",
       "                   0.01886345, -0.00393983],\n",
       "                 [-0.03589988, -0.00455565,  0.00076276, ..., -0.02145007,\n",
       "                  -0.00180798, -0.0133148 ]], dtype=float32),\n",
       "})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "params_reinit[\"model\"][\"decoder\"][\"embed_positions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray(-0.09320386, dtype=float32),\n",
       " DeviceArray(0.08769083, dtype=float32))"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding_new = params_reinit[\"model\"][\"decoder\"][\"embed_positions\"][\"embedding\"]\n",
    "embedding_new.min(), embedding_new.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'embedding': DeviceArray([[ 0.03459017, -0.0065838 , -0.11748601, ..., -0.01451578,\n",
       "               -0.03927238, -0.00266367],\n",
       "              [-0.03116009,  0.00438436,  0.02691377, ..., -0.02886203,\n",
       "               -0.01095741, -0.02649871],\n",
       "              [-0.03568491, -0.0086962 ,  0.01851564, ..., -0.04736514,\n",
       "                0.05310551, -0.01648099],\n",
       "              ...,\n",
       "              [-0.02454913,  0.03746822, -0.02269235, ...,  0.03377315,\n",
       "                0.003004  ,  0.04975331],\n",
       "              [-0.05145862,  0.04472217,  0.11103845, ...,  0.04581303,\n",
       "                0.02850476,  0.00554514],\n",
       "              [-0.01037806,  0.00281054, -0.0485299 , ..., -0.03325456,\n",
       "               -0.0058979 ,  0.01733843]], dtype=float32)}"
      ]
     },
     "execution_count": 17,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "params_original[\"model\"][\"decoder\"][\"embed_positions\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(DeviceArray(-0.25866088, dtype=float32),\n",
       " DeviceArray(0.08769083, dtype=float32))"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "embedding_original = params_original[\"model\"][\"decoder\"][\"embed_positions\"][\"embedding\"]\n",
    "embedding_original.min(), embedding_new.max()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert jnp.allclose(embedding_new, embedding_original).item() == False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "lm_head_original = params_original[\"lm_head\"][\"kernel\"]\n",
    "lm_head_reinit = params_reinit[\"lm_head\"][\"kernel\"]\n",
    "assert jnp.allclose(lm_head_reinit, lm_head_original).item() == False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [],
   "source": [
    "assert jnp.allclose(\n",
    "    params_reinit[\"model\"][\"encoder\"][\"layers\"][\"FlaxBartEncoderLayers\"][\n",
    "        \"FlaxBartAttention_0\"\n",
    "    ][\"k_proj\"][\"kernel\"],\n",
    "    params_original[\"model\"][\"encoder\"][\"layers\"][\"FlaxBartEncoderLayers\"][\n",
    "        \"FlaxBartAttention_0\"\n",
    "    ][\"k_proj\"][\"kernel\"],\n",
    ").item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Save checkpoint for retrain"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Finally, we save the resulting model to retrain those layers."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "checkpoint_dir = \"mini-reinit\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "tcmalloc: large alloc 3367796736 bytes == 0x5633f235a000 @  0x7f95ccad0680 0x7f95ccaf0bdd 0x7f95be99e29f 0x7f95be9a7750 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a87b4 0x7f95be9a4fc4 0x7f95be9a571e 0x5630fb630f94 0x5630fb58e2da 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb625fe3 0x5630fb626d24 0x5630fb58d73d 0x5630fb626be4 0x5630fb58d088 0x5630fb625fe3 0x5630fb627709 0x5630fb58d73d 0x5630fb625fe3 0x5630fb6d2a7c 0x5630fb626dbb 0x5630fb70933e 0x5630fb630571\n"
     ]
    }
   ],
   "source": [
    "model.save_pretrained(checkpoint_dir, params=params_reinit)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Upload checkpoint to W&B"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "import wandb\n",
    "from pathlib import Path"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mpcuenq\u001b[0m (\u001b[33mdalle-mini\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "Tracking run with wandb version 0.12.21"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Run data is saved locally in <code>/home/pedro/code/dalle-mini/dalle-mini/tools/train/wandb/run-20220722_105625-2v9szi3q</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Syncing run <strong><a href=\"https://wandb.ai/dalle-mini/dalle-mini/runs/2v9szi3q\" target=\"_blank\">astral-durian-2957</a></strong> to <a href=\"https://wandb.ai/dalle-mini/dalle-mini\" target=\"_blank\">Weights & Biases</a> (<a href=\"https://wandb.me/run\" target=\"_blank\">docs</a>)<br/>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src=\"https://wandb.ai/dalle-mini/dalle-mini/runs/2v9szi3q?jupyter=true\" style=\"border:none;width:100%;height:420px;display:none;\"></iframe>"
      ],
      "text/plain": [
       "<wandb.sdk.wandb_run.Run at 0x7f959a563b80>"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wandb.init(\n",
    "    entity=\"dalle-mini\",\n",
    "    project=\"dalle-mini\",\n",
    "    job_type=\"Seq2Seq\",\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "artifact = wandb.Artifact(\n",
    "    name=f\"model-{wandb.run.id}\",\n",
    "    type=\"DalleBart_model\",\n",
    "    metadata={\"embeddings\": \"reset\"},\n",
    ")\n",
    "\n",
    "for filename in [\"config.json\", \"flax_model.msgpack\"]:\n",
    "    artifact.add_file(f\"{Path(checkpoint_dir) / filename}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<wandb.sdk.wandb_artifacts.Artifact at 0x7f95984c3fd0>"
      ]
     },
     "execution_count": 27,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "wandb.run.log_artifact(artifact)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "Waiting for W&B process to finish... <strong style=\"color:green\">(success).</strong>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "VBox(children=(Label(value='1670.207 MB of 1670.207 MB uploaded (0.000 MB deduped)\\r'), FloatProgress(value=1.…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Synced <strong style=\"color:#cdcd00\">astral-durian-2957</strong>: <a href=\"https://wandb.ai/dalle-mini/dalle-mini/runs/2v9szi3q\" target=\"_blank\">https://wandb.ai/dalle-mini/dalle-mini/runs/2v9szi3q</a><br/>Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "Find logs at: <code>./wandb/run-20220722_105625-2v9szi3q/logs</code>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "wandb.finish()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "----"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "collapsed_sections": [],
   "include_colab_link": true,
   "machine_shape": "hm",
   "name": "DALL·E mini - Inference pipeline.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
